import copy, os
import numpy as np
import torch
import torch.nn as nn
from Network.Dists.forward_mask import DiagGaussianForwardPadMaskNetwork 
from Network.Dists.forward_hot import DiagGaussianForwardPadHotNetwork 
from Network.Dists.forward_multi import DiagGaussianForwardMultiMaskNetwork 
from Network.Dists.inter_mask import InteractionMaskNetwork
from Network.Dists.inter_select import InteractionSelectionMaskNetwork
from Network.network_utils import reset_parameters
from Network.net_types import network_type
from Network.General.Factor.key_query import KeyQueryEncoder
from ACState.object_dict import ObjDict
from Record.file_management import save_to_pickle, load_from_pickle

def init_passive(args):
    """Create the passive forward model shared across all objects, if requested.

    A network is only built when "passive" appears in the combined
    train/pretrain forms and ``args.inter.use_active_as_passive`` is falsy.

    Args:
        args: configuration namespace; reads ``args.inter``, ``args.factor``
            and ``args.interaction_net`` (deep-copied, never mutated).

    Returns:
        A ``DiagGaussianForwardPadMaskNetwork`` or ``None`` when no passive
        model is needed.
    """
    requested_forms = args.inter.train_forms + args.inter.pretrain_forms
    wants_passive = "passive" in requested_forms
    if args.inter.use_active_as_passive or not wants_passive:
        return None  # don't initialize wasted space

    factor = args.factor
    net_args = copy.deepcopy(args.interaction_net)
    # flattened sizes cover every object, per-object sizes are kept alongside
    net_args.num_inputs = factor.object_dim * factor.num_objects
    net_args.num_outputs = factor.single_obj_dim * factor.num_objects
    net_args.object_dim = factor.object_dim
    net_args.output_dim = factor.single_obj_dim
    # NOTE(review): stored on the top-level namespace, whereas init_single_passive
    # stores it under ``.factor.first_obj_dim`` -- confirm which one the network reads
    net_args.first_obj_dim = factor.single_obj_dim
    # should be from point, multi_mlp otherwise issues may arise
    net_args.net_type = args.inter.passive_net_type
    net_args.aggregate_final = False
    net_args.factor.pre_embed = net_args.inter_net.shared_encoding
    # uses a mask distributional network, but the mask is never used
    return DiagGaussianForwardPadMaskNetwork(net_args)

def init_single_passive(args):
    """Create one passive forward model per trained name, if requested.

    Every name in ``args.inter.train_names`` always gets an entry in the
    result. Entries stay ``None`` unless "single_passive" is among the
    train/pretrain forms and neither ``use_all_as_single`` nor
    ``use_active_as_passive`` is set.

    Args:
        args: configuration namespace; reads ``args.inter``, ``args.factor``
            and ``args.interaction_net`` (deep-copied, never mutated).

    Returns:
        dict mapping each train name to a ``DiagGaussianForwardPadMaskNetwork``
        or ``None``.
    """
    inter_cfg = args.inter
    requested_forms = inter_cfg.train_forms + inter_cfg.pretrain_forms
    wants_single_passive = "single_passive" in requested_forms
    models_by_name = dict.fromkeys(inter_cfg.train_names)  # every value starts as None

    if inter_cfg.use_all_as_single or inter_cfg.use_active_as_passive or not wants_single_passive:
        return models_by_name

    # settings shared by every per-name network
    template = copy.deepcopy(args.interaction_net)
    template.num_inputs = args.factor.object_dim  # TODO: does not support MLPs for multiobject
    template.object_dim = args.factor.object_dim
    template.num_outputs = args.factor.single_obj_dim
    template.output_dim = args.factor.single_obj_dim
    template.net_type = inter_cfg.passive_net_type  # should be from point otherwise issues may arise
    template.aggregate_final = True
    template.factor.pre_embed = template.inter_net.shared_encoding

    for name in inter_cfg.train_names:
        # each network gets its own copy so the name-specific fields do not leak
        named_args = copy.deepcopy(template)
        named_args.factor.first_obj_dim = args.factor.named_first_obj_dim[name]
        named_args.factor.name_idx = args.factor.name_idxes[name]
        # uses a mask distributional network, but the mask is never used
        models_by_name[name] = DiagGaussianForwardPadMaskNetwork(named_args)
    return models_by_name


def init_active_inter_all(args):
    """Create the all-object active/interaction model pair, if requested.

    The pair is built when any of "all", "all_mask" or "all_rand_mask" is in
    the combined train/pretrain forms, or when "passive" is requested together
    with ``args.inter.use_active_as_passive``.

    Returns:
        Whatever ``init_active_inter`` returns for ``args.factor.first_obj_dim``,
        or ``(None, None)`` when the pair is not needed.
    """
    requested_forms = args.inter.train_forms + args.inter.pretrain_forms
    active_serves_as_passive = "passive" in requested_forms and args.inter.use_active_as_passive
    needs_all_model = active_serves_as_passive or any(
        form in requested_forms for form in ("all", "all_mask", "all_rand_mask")
    )
    if not needs_all_model:
        return None, None
    return init_active_inter(args, args.factor.first_obj_dim)

def init_active_inter_full(args):
    """Create per-name active/interaction model pairs, if requested.

    Every name in ``args.inter.train_names`` always gets an entry in both
    returned dicts; entries stay ``None`` unless one of the supported
    training forms asks for the per-name ("full") models.

    Returns:
        ``(active_models, inter_models)``: two dicts keyed by train name whose
        values are the pair returned by ``init_active_inter`` or ``None``.
    """
    inter_cfg = args.inter
    requested_forms = inter_cfg.train_forms + inter_cfg.pretrain_forms

    # as long as we are not using all as single, and doing one of the possible training values
    # TODO: "binaries" and "null_bin" could only initialize the inter model, which is not supported
    using_full = (
        ("single_passive" in requested_forms
         and inter_cfg.use_active_as_passive
         and not inter_cfg.use_all_as_single)
        or ("pair" in requested_forms and inter_cfg.use_full_as_pair)
        or (any(form in requested_forms for form in ("full", "mask", "rand_mask", "all_mask"))
            and not inter_cfg.use_all_as_single)
    )

    active_models = dict.fromkeys(inter_cfg.train_names)
    inter_models = dict.fromkeys(inter_cfg.train_names)
    if not using_full:
        return active_models, inter_models

    for name in inter_cfg.train_names:
        active_models[name], inter_models[name] = init_active_inter(
            args, args.factor.named_first_obj_dim[name]
        )
    return active_models, inter_models

def init_pair(args, extractor):
    """Create one forward model per entry of ``args.inter.pair_names``, if requested.

    Models are only built when "pair" is among the train/pretrain forms and
    neither ``use_all_as_single`` nor ``use_full_as_pair`` is set.

    Args:
        args: configuration namespace; ``args.interaction_net`` is deep-copied,
            never mutated.
        extractor: object whose ``get_index`` is called with the list of source
            names parsed from each pair string.

    Returns:
        dict mapping each pair name to a ``DiagGaussianForwardPadMaskNetwork``;
        empty when pair models are not requested.
    """
    # settings shared by every pair network
    template = copy.deepcopy(args.interaction_net)
    template.object_dim = args.factor.object_dim
    template.num_outputs = args.factor.single_obj_dim
    template.output_dim = args.factor.single_obj_dim
    template.aggregate_final = True
    template.factor.first_obj_dim = args.factor.single_obj_dim
    template.factor.query_aggregate = True
    template.factor.pre_embed = template.inter_net.shared_encoding

    # as long as we are not using all as single, and doing one of the possible training values
    requested_forms = args.inter.train_forms + args.inter.pretrain_forms
    if "pair" not in requested_forms or args.inter.use_all_as_single or args.inter.use_full_as_pair:
        return {}

    models_by_pair = {}
    for pair_name in args.inter.pair_names:
        # the text before '->' is a '|'-separated list of source names
        source_names = pair_name.split('->')[0].split("|")
        pair_args = copy.deepcopy(template)
        pair_args.num_inputs = args.factor.object_dim * len(source_names)
        pair_args.factor.name_idx = extractor.get_index(source_names)
        models_by_pair[pair_name] = DiagGaussianForwardPadMaskNetwork(pair_args)
    return models_by_pair


def init_active_inter(args, first_obj_dim):
    """Build one (active forward model, interaction model) pair.

    The same initialization is used for the "all" and the "full" models; "full"
    simply calls this once per name. Each network is configured from its own
    deep copy of ``args.interaction_net``, so ``args`` is never mutated.

    Args:
        args: configuration namespace; reads ``args.interaction_net``,
            ``args.factor``, ``args.masking.selection_mask``,
            ``args.active.use_cluster`` and ``args.active.use_population``.
        first_obj_dim: value stored as ``factor.first_obj_dim`` in both configs.

    Returns:
        ``(active_model, interaction_model)``. ``interaction_model`` is ``None``
        when ``args.masking.selection_mask`` is truthy.
    """
    # TODO: implement shared key query, otherwise this can be separated into two functions

    interaction_model_args = copy.deepcopy(args.interaction_net)
    interaction_model_args.num_inputs = args.factor.object_dim * args.factor.num_objects
    interaction_model_args.object_dim = args.factor.object_dim
    interaction_model_args.num_outputs = 1 # TODO: could create issues if using an MLP at the end
    interaction_model_args.output_dim = 1
    interaction_model_args.mask_attn.return_mask = False
    interaction_model_args.softmax_output = False
    interaction_model_args.aggregate_final = False
    interaction_model_args.factor.query_aggregate = False
    interaction_model_args.factor_net.no_decode = False
    interaction_model_args.factor.first_obj_dim = first_obj_dim
    # The original expression nested two conditionals on the same flag, which made
    # its InteractionSelectionMaskNetwork branch unreachable. The effective behavior
    # is kept here: no interaction model at all when a selection mask is used.
    # NOTE(review): if InteractionSelectionMaskNetwork was meant to be built for
    # selection masks, that has to be restored deliberately -- confirm intent.
    if args.masking.selection_mask:
        interaction_model = None
    else:
        interaction_model = InteractionMaskNetwork(interaction_model_args)

    # The active model is constructed after the interaction model on purpose:
    # keeping the construction order keeps seeded parameter initialization stable.
    active_model_args = copy.deepcopy(args.interaction_net)
    active_model_args.num_inputs = args.factor.object_dim * args.factor.num_objects
    active_model_args.num_outputs = args.factor.single_obj_dim * args.factor.num_objects
    active_model_args.object_dim = args.factor.object_dim
    active_model_args.output_dim = args.factor.single_obj_dim
    active_model_args.factor.query_aggregate = True
    active_model_args.factor.pre_embed = active_model_args.inter_net.shared_encoding
    active_model_args.factor.first_obj_dim = first_obj_dim
    # use_cluster takes precedence over use_population
    if args.active.use_cluster:
        active_model = DiagGaussianForwardPadHotNetwork(active_model_args)
    elif args.active.use_population:
        active_model = DiagGaussianForwardMultiMaskNetwork(active_model_args)
    else:
        active_model = DiagGaussianForwardPadMaskNetwork(active_model_args)

    return active_model, interaction_model

def set_net_parameters(base_model, args, params):
    """Push runtime parameters into, and optionally reset, the sub-networks of ``base_model``.

    For every network that exposes ``key_query_encoder`` the encoder's
    ``soft_mask_param`` is set to ``params.soft_mask_param``. Networks whose
    category name contains "inter" are reset when ``params.reset_inter`` is
    truthy; the "full" and "all" categories are reset when
    ``params.reset_active`` is truthy.

    Per-name containers (``single_passive_models``, ``pair_models``,
    ``full_models``, ``inter_models``) that are plain dicts are expanded so
    the settings reach each contained network; ``None`` entries are skipped.
    Previously the dict itself was treated as the network, so the
    ``key_query_encoder`` check never matched and ``reset_parameters`` was
    handed a dict.

    Args:
        base_model: object holding the ``all_*`` networks and per-name containers.
        args: configuration namespace; only ``args.active.resetting`` is read,
            and only when a reset actually happens.
        params: namespace providing ``soft_mask_param``, ``reset_inter`` and
            ``reset_active``.
    """
    def set_params(net_name, network):
        # implement all the parameter setting logic here TODO: more setting can happen here
        if network is None:
            return
        if isinstance(network, dict):
            # name -> network container: apply the category settings to each entry
            for sub_network in network.values():
                set_params(net_name, sub_network)
            return
        if hasattr(network, "key_query_encoder"):
            network.key_query_encoder.soft_mask_param = params.soft_mask_param
        if params.reset_inter and "inter" in net_name:
            reset_parameters(network, args.active.resetting.reset_form, n_layers=args.active.resetting.reset_layers)
        if params.reset_active and net_name in ("full", "all"):
            reset_parameters(network, args.active.resetting.reset_form, n_layers=args.active.resetting.reset_layers)

    named_networks = (
        ("all_passive", base_model.all_passive_model),
        ("all", base_model.all_model),
        ("all_inter", base_model.all_inter_model),
        ("single_passive", base_model.single_passive_models),
        ("pair", base_model.pair_models),
        ("full", base_model.full_models),
        ("inter", base_model.inter_models),
    )
    for net_name, network in named_networks:
        set_params(net_name, network)

def save_model(model, pth):
    """Save every non-None sub-model of ``model`` as its own file inside ``pth``.

    ``model.cpu()`` is called first, then the directory is created if needed.
    Whole modules are written with ``torch.save``.
    TODO: saved in entirety right now, possibly replace with state_dict.

    File names (must stay in sync with ``load_model``):
        all_passive_model.pth, all_model.pth, all_inter_model.pth,
        <name>_single_passive_model.pth, <pair>pair_model.pth,
        <name>full_model.pth, <name>inter_model.pth
    Only the single-passive files have an underscore after the name.

    Args:
        model: object exposing ``cpu()``, the ``all_*`` models, the per-name
            dicts, and ``train_names`` / ``pair_names`` to index them.
        pth: target directory.

    Raises:
        OSError: if ``pth`` cannot be created (an already existing directory
            is fine). Such errors used to be swallowed here and only surfaced
            later as a failing ``torch.save``.
    """
    model.cpu()
    os.makedirs(pth, exist_ok=True)

    whole_models = (
        (model.all_passive_model, "all_passive_model.pth"),
        (model.all_model, "all_model.pth"),
        (model.all_inter_model, "all_inter_model.pth"),
    )
    for module, filename in whole_models:
        if module is not None:
            torch.save(module, os.path.join(pth, filename))

    # (container, attribute listing its keys, file name suffix)
    keyed_models = (
        (model.single_passive_models, "train_names", "_single_passive_model.pth"),
        (model.pair_models, "pair_names", "pair_model.pth"),
        (model.full_models, "train_names", "full_model.pth"),
        (model.inter_models, "train_names", "inter_model.pth"),
    )
    for models, names_attr, suffix in keyed_models:
        if models is None:
            continue
        # the key list is only looked up when its container is present;
        # entries that are None are saved as-is so load_model finds every file
        for n in getattr(model, names_attr):
            torch.save(models[n], os.path.join(pth, n + suffix))

def load_model(model, pth, device="cpu"):
    """Replace the non-None sub-models of ``model`` with the ones stored in ``pth``.

    Only slots that are currently not ``None`` are loaded, using the same file
    names ``save_model`` writes. Afterwards ``model.assign_module_from_model()``
    is called and the model is moved to ``device``.

    Args:
        model: object exposing the ``all_*`` models, the per-name dicts,
            ``train_names`` / ``pair_names``, ``assign_module_from_model()``,
            ``cpu()`` and ``cuda()``.
        pth: directory to load from; an empty or ``None`` path leaves ``model``
            untouched.
        device: ``"cpu"`` or a value accepted by ``model.cuda(device=...)``.

    Returns:
        The model returned by ``model.cpu()`` / ``model.cuda(...)``, or the
        unchanged ``model`` when ``pth`` is empty.
    """
    print(pth)
    # `not pth` also covers None, which used to crash on len()
    if not pth: return model
    print("loading model from, ", pth)

    def load_file(filename):
        # SECURITY: weights_only=False unpickles arbitrary objects, so only load
        # checkpoints from trusted sources. Switching to weights_only=True would
        # require registering the project classes via
        # torch.serialization.add_safe_globals (TODO: list not yet complete).
        # map_location="cpu" lets checkpoints holding CUDA tensors load on hosts
        # without a GPU; the final device move happens at the end.
        return torch.load(os.path.join(pth, filename), map_location="cpu", weights_only=False)

    if model.all_passive_model is not None:
        model.all_passive_model = load_file("all_passive_model.pth")
    if model.all_model is not None:
        model.all_model = load_file("all_model.pth")
    if model.all_inter_model is not None:
        model.all_inter_model = load_file("all_inter_model.pth")
    if model.single_passive_models is not None:
        for n in model.train_names:
            model.single_passive_models[n] = load_file(n + "_single_passive_model.pth")
    if model.pair_models is not None:
        for n in model.pair_names:
            model.pair_models[n] = load_file(n + "pair_model.pth")
    if model.full_models is not None:
        for n in model.train_names:
            model.full_models[n] = load_file(n + "full_model.pth")
    if model.inter_models is not None:
        for n in model.train_names:
            model.inter_models[n] = load_file(n + "inter_model.pth")
    model.assign_module_from_model()
    if device == "cpu":
        model = model.cpu()
    else:
        model = model.cuda(device=device)
    return model